David Molony

Machine Learning, Data Science, Medical Imaging

Tensorflow image sequences


Background

This post will try to serve as a practical guide on how to import a sequence of images into tensorflow using the Dataset API. If you're not familiar with the Dataset API you should check out the tensorflow dataset guide. The beauty of the Dataset API is that it lets us import a data source and apply image decoding, augmentation and batching functions, all within the tensorflow environment. In this tutorial I will import a publicly available video data source into tensorflow. I will first format the data and generate a TFRecords file, and then read this TFRecords file back into tensorflow. This tutorial is designed for users with an intermediate level of familiarity with tensorflow.

Working with images

A problem I've found with the Dataset API is that it does not play nicely with importing sequences of images, and a search of StackOverflow shows others with the same issue. Often the easiest way to read images into tensorflow with the Dataset API is to read a text file in which each image path is stored and generate a list of the filenames. Then we can create a Dataset from tensor slices.

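# read the image paths and labels from a text file (read_files is defined later in this post)
image_list, label_list = read_files(directory, data_file)
data = tf.data.Dataset.from_tensor_slices((image_list, label_list))
data_batch = data.batch(batch_size)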

However, this loads one image at a time. We could batch these images into a 4D tensor, but things start to get messy if we want to create a mini-batch from different video sources. To create a mini-batch of image sequences we essentially need a 5D tensor (batch, time, height, width, depth), as illustrated in the figure below. The easiest way to construct this tensor is through TFRecords, and I will show how this can be done.

[Figure: a mini-batch of image sequences as a 5D tensor (batch, time, height, width, depth)]
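The following sketch makes the 5D layout concrete without any tensorflow code. It is a minimal illustration, not part of the pipeline in this post. make_sequence_batch is a hypothetical helper that takes one list of frames per video. It refuses any time window that runs past the shortest video and returns a grid where row b is video b and column t is time step t. If each entry were a height x width x depth frame, stacking this grid would give the (batch, time, height, width, depth) tensor.

```python
def make_sequence_batch(videos, start, sequence_length):
    """Arrange one time window from several videos into a (batch, time) grid.

    videos: list with one list of frames per video (frames can be any objects,
            e.g. file paths or height x width x depth arrays)
    start: index of the first frame in the window
    Returns grid with grid[b][t] == videos[b][start + t].
    """
    # the window must fit inside every video, so the shortest video is the limit
    shortest = min(len(frames) for frames in videos)
    if start + sequence_length > shortest:
        raise ValueError("window runs past the end of the shortest video")
    return [frames[start:start + sequence_length] for frames in videos]


# usage: two hypothetical videos of 150 and 180 frames, represented by frame names
videos = [["a_frame%d" % t for t in range(150)], ["b_frame%d" % t for t in range(180)]]
grid = make_sequence_batch(videos, start=0, sequence_length=5)
print(len(grid), len(grid[0]))  # 2 5
```

Every row has the same length because the window is bounded by the shortest video. The TFRecords pipeline below relies on the same constraint when it truncates videos.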

Getting started

First we will identify a data source that requires us to work with sequences of images. UCF50 is an action recognition dataset of realistic action videos, collected from YouTube, with 50 action categories. With this dataset we might want to classify the action. A model that incorporates several frames of each video, rather than one that works on individual frames, might boost accuracy.

Let's start by downloading and extracting the dataset:

!wget http://crcv.ucf.edu/data/UCF50.rar
!unrar x UCF50.rar

Next we will import all the modules we need. Most of these are common libraries, but we also need OpenCV (cv2) to extract the frames from the .avi video files. If running on Linux you will need a cv2 version newer than 3.3.0.

import os
import operator
import cv2
import PIL.Image as im
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

Extract frames

The UCF50 dataset contains .avi files for 50 actions. To use them in a machine learning algorithm we first need to extract the frames. For each action there are multiple examples, organized into groups and clips. As we extract the frames we will write a text file for each video, containing the path of each extracted frame and its class label. We will use these files when generating our TFRecords file.

To achieve all this we are going to use a loop that:
1. Extracts every 5th frame from each video and writes it as a JPEG file to the same folder as the video
2. Creates a text file for each video that lists the path and class label of every frame we extract

As the UCF50 dataset is extremely large, we are just going to look at a small subset of it by extracting the frames for every 10th action. If you want to work with the entire dataset, delete the line dirs = dirs[0::10] below, but be aware that there are over a million frames in total.

home_path = 'UCF50'
# sort the class folders so the class-to-integer mapping is the same on every run
dirs = sorted(os.listdir(home_path))
dirs = [directory for directory in dirs if os.path.isdir(os.path.join(home_path, directory))]

# keep every 10th class
dirs = dirs[0::10]

# create a dictionary for each class label
frame_class = {}
for i, current_dir in enumerate(dirs):
    files = os.listdir(os.path.join(home_path, current_dir))
    files = [vid for vid in files if vid.endswith('.avi')]

    # assign an integer value to represent each class
    frame_class[current_dir] = i + 1

    for vid in files:
        # open the video file
        video = cv2.VideoCapture(os.path.join(home_path, current_dir, vid))
        vid_name = vid.split('.avi')[0]

        # one text file per video in home_path, named without the leading 'v_'
        # ('w' rather than 'a' so that re-running does not duplicate lines)
        f = open(os.path.join(home_path, vid_name[2:]) + '.txt', 'w')

        count = 0
        success = True
        while success:
            # read every frame so that count matches the frame number in the video
            success, image = video.read()
            # write every 5th frame to disk
            if success and count % 5 == 0:
                # define the output filename
                frame_name = vid_name + "_frame%d.jpg" % count
                # write frame as JPEG file
                cv2.imwrite(os.path.join(home_path, current_dir, frame_name), image)
                # write text file with frame path (relative to home_path) and frame class
                f.write(os.path.join(current_dir, frame_name) + " %d\n" % (i + 1))
            count += 1
        f.close()
        video.release()
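Running the loop over the whole subset takes a while, so it can help to check the naming convention on its own first. The following is a minimal sketch, not part of the original pipeline. frame_text_lines is a hypothetical helper that takes a class folder, a video filename, a number of frames and a label. The number of frames is assumed to be known here; the loop itself simply reads until video.read() fails. The helper returns the lines the loop writes to that video's text file, keeping every 5th frame counting from 0.

```python
import os


def frame_text_lines(class_dir, video_file, num_frames, label, step=5):
    """Return the text-file lines written for one video (hypothetical helper).

    Each kept frame becomes "<class_dir>/<video name>_frame<N>.jpg <label>",
    where N is the frame number in the video and N % step == 0.
    num_frames is assumed to be known here; the extraction loop instead
    reads frames until video.read() reports failure.
    """
    video_name = video_file.split('.avi')[0]
    lines = []
    for count in range(num_frames):
        if count % step == 0:
            frame_path = os.path.join(class_dir, video_name + "_frame%d.jpg" % count)
            lines.append("%s %d" % (frame_path, label))
    return lines


# usage: an example filename in the UCF50 naming style; a 12-frame clip keeps frames 0, 5 and 10
for line in frame_text_lines("BaseballPitch", "v_BaseballPitch_g01_c01.avi", num_frames=12, label=1):
    print(line)
```

These are the same "path label" lines that read_files later splits on a single space.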

What are TFRecords?

Before I detail how the TFRecords file is generated, I will give a brief explanation of TFRecords. TFRecords is a standard tensorflow format that stores your data as a sequence of binary records. Each record is typically a serialized tf.train.Example or tf.train.SequenceExample; for sequential data we use tf.train.SequenceExample. One thing to watch with TFRecords is that the file size can get extremely large, especially here, because we store decoded, uncompressed pixel arrays. I won't go into the details of tf.train.SequenceExample, as Denny Britz (linked at the end of this article) has a great introduction on how to use it. The main thing you will notice is that each sequence example contains a context and feature_lists. The context contains non-sequential information about the sequence, while feature_lists contains the sequence data, e.g. words, images or stock prices. For our sequence of images, context information could be the image height and width as well as the sequence length.
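The following is a minimal, framework-free sketch of how the records in a TFRecords file are laid out. In the documented format, each record consists of four parts:

- an 8-byte little-endian length
- a 4-byte masked CRC32C of that length
- the payload bytes
- a 4-byte masked CRC32C of the payload

The hypothetical iter_tfrecord_payloads below walks that framing. It skips the checksums rather than verifying them, so treat it as an inspection aid only; the post itself reads the file with tf.data.TFRecordDataset.

```python
import io
import struct


def iter_tfrecord_payloads(stream):
    """Yield the payload bytes of each record in a binary TFRecords stream.

    Record layout: uint64 length | uint32 masked CRC32C of length |
    payload (length bytes) | uint32 masked CRC32C of payload.
    The CRCs are skipped, not verified.
    """
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of the stream
        if len(header) < 8:
            raise ValueError("truncated record length")
        (length,) = struct.unpack('<Q', header)
        stream.read(4)  # skip CRC of the length
        payload = stream.read(length)
        if len(payload) < length:
            raise ValueError("truncated record payload")
        stream.read(4)  # skip CRC of the payload
        yield payload


def fake_record(payload):
    """Frame a payload like a TFRecord, with zeroed placeholder CRCs (not valid CRCs)."""
    return struct.pack('<Q', len(payload)) + b'\x00' * 4 + payload + b'\x00' * 4


# usage: two records in an in-memory stream
stream = io.BytesIO(fake_record(b'first') + fake_record(b'second'))
print(list(iter_tfrecord_payloads(stream)))  # [b'first', b'second']
```

To inspect a real file, open it with open('train.tfrecord', 'rb'). In this post each payload is a serialized tf.train.SequenceExample, which the protobuf method tf.train.SequenceExample.FromString can decode.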

Generate_tfrecord

Now that we know what a TFRecords file is, we will generate one using the function generate_tfrecord. We have a text file for each video, so we can group videos into mini-batches of image sequences. To keep things simple, we truncate every video in a mini-batch to the length of the shortest video in that mini-batch. That is to say, if the 4 videos in our mini-batch contain 150, 160, 170 and 180 frames, we truncate all of them to 150 frames. Alternatively, we could pad the shorter sequences. The first helper function, make_file_list, reads all the text files and sorts the videos by their number of frames. As a result, videos grouped into the same mini-batch have similar lengths, and truncation discards relatively few frames. We have several other helper functions for reading the images, decoding them and converting them to bytes.
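The benefit of sorting can be checked with a little arithmetic. The following is a minimal sketch using made-up frame counts. frames_discarded is a hypothetical helper that groups a list of (text file, frame count) tuples batch_size at a time, as generate_tfrecord does. It then counts how many frames truncation to the shortest video in each group throws away.

```python
def frames_discarded(file_list, batch_size):
    """Count frames lost when each mini-batch is truncated to its shortest video.

    file_list: list of (text_file, num_frames) tuples, grouped batch_size at a time;
    a trailing group smaller than batch_size is ignored.
    """
    discarded = 0
    for i in range(0, len(file_list) - batch_size + 1, batch_size):
        counts = [num for _, num in file_list[i:i + batch_size]]
        shortest = min(counts)
        discarded += sum(num - shortest for num in counts)
    return discarded


# hypothetical frame counts for 8 videos
videos = [('a.txt', 150), ('b.txt', 400), ('c.txt', 160), ('d.txt', 390),
          ('e.txt', 170), ('f.txt', 380), ('g.txt', 180), ('h.txt', 370)]
unsorted_loss = frames_discarded(videos, batch_size=4)
sorted_loss = frames_discarded(sorted(videos, key=lambda item: item[1]), batch_size=4)
print(unsorted_loss, sorted_loss)  # 920 120
```

Sorting puts the four short videos and the four long videos in separate mini-batches. That is why make_file_list sorts by frame count before any grouping happens.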

The function generate_tfrecord does all the work. We start by creating a tf.python_io.TFRecordWriter, a class that writes records to a TFRecords file, and we write each sequence example to this object. Before defining the tf.train.SequenceExample, we read batch_size videos and add their frame lists to full_batch_image_list. Based on our chosen sequence length, we then read each window of images and convert it into a list of bytes features using our helper function image_sequence. We similarly read the labels from the text file and convert them to the appropriate TFRecords format. Following this, we create our context features and bundle them into the dictionary context_dict; similarly, our sequence data is bundled into the dictionary sequence_dict. The variables sequence_context and sequence_list use tf.train.Features and tf.train.FeatureLists for the context and the sequence data respectively. Finally, each sequence example is built with tf.train.SequenceExample and written to the file. We write the current window for every one of the batch_size videos before moving on to the next window of the same videos.

# read filenames into list
def make_file_list(directory):
    """this function makes a list of sorted text files sorted by the sequence length"""
    files = os.listdir(directory)
    files = [name for name in files if '.txt' in name]
    file_dict = {}

    for file in files:
        f = open(os.path.join(directory, file), 'r')
        num_files = len(f.read().splitlines())
        file_dict[file] = num_files
        f.close()
    sorted_file = sorted(file_dict.items(), key=operator.itemgetter(1))

    return sorted_file

def read_files(directory, file):
    f = open(os.path.join(directory, file), 'r')
    data=f.read().splitlines()
    f.close()
    # each line is "<relative frame path> <integer label>", separated by a single space
    image_list = [name.split(' ')[0] for name in data]
    label_list = [int(name.split(' ')[1]) for name in data]

    return image_list, label_list

def decode_images(image):
    """this function reads an image and converts it to a numpy array"""
    image = np.asarray(im.open(image.strip('\n')))

    return image

# create a image sequence where each entry is an image of bytes
def image_sequence(directory, image_list):
    """this function takes a list of images and returns the list in bytes"""
    image_bytes_list = []
    for image in image_list:
        image = decode_images(os.path.join(directory, image))
        # raw, uncompressed pixel bytes (tobytes replaces the deprecated tostring)
        image_bytes = image.tobytes()
        image_bytes = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
        image_bytes_list.append(image_bytes)

    return image_bytes_list

def label_sequence(directory, label_list):
    """this function takes a list of labels and returns the list in int64"""
    label_int_list = []
    for label in label_list:
        label_int = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        label_int_list.append(label_int)

    return label_int_list

def generate_tfrecord(directory, file_list, batch_size, sequence_length, imsize, train_or_test):
    """directory is the folder containing the text files and the class sub-folders
    file_list is a list of (text file, number of frames) tuples sorted by number of frames
    sequence_length is the number of frames in each written sequence
    train_or_test is a string prefix for the output file
    batch_size indicates the batch size
    imsize is a list that indicates the shape [height, width, depth]"""
    writer = tf.python_io.TFRecordWriter(train_or_test + '.tfrecord')

    image_height = imsize[0]
    image_width = imsize[1]
    image_depth = imsize[2]
    # step through the sorted videos batch_size at a time; a trailing group of fewer
    # than batch_size videos is skipped so that file_list[j] never runs past the end
    for i in range(0, len(file_list) - batch_size + 1, batch_size):
        # read files
        num_files = float('inf')
        full_batch_image_list = []
        full_batch_label_list = []
        for j in range(i, i + batch_size):
            # find the number of frames in the shortest video of the mini-batch
            num_files = min(num_files, file_list[j][1])
            image_list, label_list = read_files(directory, file_list[j][0])
            full_batch_image_list.append(image_list)
            full_batch_label_list.append(label_list)

        # truncate every video in the mini-batch to the length of the shortest one
        full_batch_image_list = [batch[0:num_files] for batch in full_batch_image_list]
        full_batch_label_list = [batch[0:num_files] for batch in full_batch_label_list]

        # write window t of every video before moving on to window t+1, so records are
        # ordered case1-t1, case2-t1, case3-t1, case4-t1, case1-t2, case2-t2, ...
        timesteps = num_files//sequence_length
        current_timestep = 0
        # with < only full windows of sequence_length frames are written
        # with <= a final window of (num_files % sequence_length) frames is also written,
        # which is empty when num_files is a multiple of sequence_length
        while current_timestep < timesteps*sequence_length:
            for l in range(batch_size):
                image_bytes_list = image_sequence(directory, full_batch_image_list[l][current_timestep:current_timestep+sequence_length])       
                label_int_list = label_sequence(directory, full_batch_label_list[l][current_timestep:current_timestep+sequence_length])

                case = full_batch_image_list[l][0].split('v_')[1].split('_frame')[0]
                # after every batch write the images to the tfrecord
                # create a feature list consisting of list of images
                images = tf.train.FeatureList(feature=image_bytes_list)
                labels = tf.train.FeatureList(feature=label_int_list)

                im_length = tf.train.Feature(int64_list=tf.train.Int64List(value=[len(image_bytes_list)]))
                im_height = tf.train.Feature(int64_list=tf.train.Int64List(value=[image_height]))
                im_width = tf.train.Feature(int64_list=tf.train.Int64List(value=[image_width]))
                im_depth = tf.train.Feature(int64_list=tf.train.Int64List(value=[image_depth]))
                im_name = tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(case)]))

                # frame numbers of this window as one string, e.g. "frame0frame5frame10frame15frame20"
                window = full_batch_image_list[l][current_timestep:current_timestep+sequence_length]
                frame_ids = [x.split('_')[-1].split('.jpg')[0] for x in window]
                frames = tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode("".join(frame_ids))]))

                # create a dictionary
                sequence_dict = {'Images': images, 'Labels': labels}
                context_dict = {'length': im_length, 'height': im_height, 'width': im_width, 'depth': im_depth, 'name': im_name, 'frames': frames}

                sequence_context = tf.train.Features(feature=context_dict)
                # now create a list of feature lists contained within dictionary
                sequence_list = tf.train.FeatureLists(feature_list=sequence_dict)

                example = tf.train.SequenceExample(context=sequence_context, feature_lists=sequence_list)

                writer.write(example.SerializeToString())

            # increment the timestep
            current_timestep += sequence_length
    writer.close()
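The interleaved order in which generate_tfrecord writes its records is what lets a plain dataset.batch(batch_size) produce a batch containing the same time window from different videos. The minimal sketch below reproduces just that ordering logic. record_order is a hypothetical helper, and its start values are positions in each video's list of extracted frames, not frame numbers in the original video.

```python
def record_order(case_names, num_frames, sequence_length):
    """Return (case, start) pairs in the order records are written for one mini-batch.

    case_names: the batch_size videos of the mini-batch
    num_frames: length every video was truncated to
    Only full windows are written, matching the `<` loop condition.
    """
    timesteps = num_frames // sequence_length
    order = []
    for start in range(0, timesteps * sequence_length, sequence_length):
        # window `start` of every video before moving to the next window
        for case in case_names:
            order.append((case, start))
    return order


cases = ['case1', 'case2', 'case3', 'case4']
order = record_order(cases, num_frames=12, sequence_length=5)
# reading back in order with a batch size of 4 gives one window per batch
for k in range(0, len(order), len(cases)):
    print(order[k:k + len(cases)])
```

With 12 frames and a sequence length of 5, only 2 full windows fit, and the last 2 frames are dropped.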

Now that we have defined how the TFRecords file is created, let's set the image height, width and depth, and specify the batch size and sequence length. The batch size and sequence length are extremely important. The sequence length is fixed when the file is written. The file must also be read back in order with the same batch size, so that each mini-batch contains the same time window from batch_size different videos. For this example we will use a batch size of 4 and a sequence length of 5. Note that the resulting TFRecords file is very large, at 30GB.

# read in data files and divide into sequence
image_height = 240
image_width = 320
image_depth = 3
batch_size = 4
sequence_length = 5
sorted_files = make_file_list(home_path)
generate_tfrecord(home_path, sorted_files, batch_size, sequence_length, [image_height, image_width, image_depth], 'train')
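The 30GB figure is less surprising once you note that image_sequence stores decoded, uncompressed pixels. Here is a rough back-of-the-envelope calculation for the settings above, ignoring the small per-record and protobuf overhead:

```python
image_height, image_width, image_depth = 240, 320, 3
sequence_length = 5

# decoded uint8 images: one byte per pixel per channel
bytes_per_frame = image_height * image_width * image_depth
bytes_per_sequence = bytes_per_frame * sequence_length
# how many raw frames fit in one gigabyte (10**9 bytes)
frames_per_gb = 10**9 // bytes_per_frame

print(bytes_per_frame, bytes_per_sequence, frames_per_gb)  # 230400 1152000 4340
```

At roughly 4,340 frames per GB, a 30GB file holds on the order of 30 x 4,340, or about 130,000, raw frames.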

Reading the TFRecords

With the TFRecords file generated, we need code that reads the TFRecords file and decodes it for tensorflow. For this purpose we will create a class, DataSequenceReader, which uses the tensorflow Dataset API for:
1. Reading the data
2. Applying functions to parse and augment the data
3. Batching the data

An instance of the class is initialized by specifying the batch size, sequence length and number of epochs. The method read_batch then reads the specified TFRecords file. I've included a simple method, rotate_sequence, that applies the same random rotation to every image in a sequence. Any other augmentation technique could be added in the same way.

class DataSequenceReader():
    def __init__(self, batch_size, sequence_length, num_epochs):
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.sequence_length = sequence_length

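    def rotate_sequence(self, image, label, im_name, frames):
        """apply the same random rotation to every frame of the sequence"""
        # tf.contrib.image.rotate expects the angle in radians
        rot_angle = tf.random_uniform([], minval=0, maxval=2 * np.pi, dtype=tf.float32)

        # image has shape (sequence_length, height, width, depth), which rotate treats as a
        # batch of images, so one call with a scalar angle rotates every frame identically
        image = tf.contrib.image.rotate(image, rot_angle)

        return image, label, im_name, frames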

    def parse_sequence(self, sequence_example):

        sequence_features = {'Images': tf.FixedLenSequenceFeature([], dtype=tf.string),
                             'Labels': tf.FixedLenSequenceFeature([], dtype=tf.int64)}

        context_features = {'length': tf.FixedLenFeature([], dtype=tf.int64),
                            'height': tf.FixedLenFeature([], dtype=tf.int64),
                            'width': tf.FixedLenFeature([], dtype=tf.int64),
                            'depth': tf.FixedLenFeature([], dtype=tf.int64),
                            'name': tf.FixedLenFeature([], dtype=tf.string),
                            'frames': tf.FixedLenFeature([], dtype=tf.string)}
        context, sequence = tf.parse_single_sequence_example(
            sequence_example, context_features=context_features, sequence_features=sequence_features)

        # get features context
        seq_length = tf.cast(context['length'], dtype = tf.int32)
        im_height = tf.cast(context['height'], dtype = tf.int32)
        im_width = tf.cast(context['width'], dtype = tf.int32)
        im_depth = tf.cast(context['depth'], dtype = tf.int32)
        im_name = context['name']
        frames = context['frames']

        # encode image
        image = tf.decode_raw(sequence['Images'], tf.uint8)
        image = tf.reshape(image, shape=(seq_length, im_height, im_width, im_depth))

        label = tf.cast(sequence['Labels'], dtype = tf.int32)

        return image, label, im_name, frames

    def read_batch(self, filename, train):
        dataset = tf.data.TFRecordDataset(filename)
        # repeat the records num_epochs times (no shuffling, which would break the window order)
        dataset = dataset.repeat(self.num_epochs)
        dataset = dataset.map(self.parse_sequence, num_parallel_calls=2)
        if train:
            # augment only the training data
            dataset = dataset.map(self.rotate_sequence, num_parallel_calls=2)
        dataset = dataset.batch(self.batch_size)

        return dataset
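The key design choice in rotate_sequence is that the random parameter is drawn once per sequence and then applied to every frame, so the sequence stays temporally consistent. The same pattern carries over to other augmentations. The following is a minimal, framework-free sketch of that pattern, using a different technique: a random horizontal flip on plain nested Python lists, rather than a rotation on tensors. flip_sequence is a hypothetical helper.

```python
import random


def flip_sequence(frames, rng):
    """Apply the same random horizontal flip to every frame of a sequence.

    frames: list of frames, each a list of rows, each row a list of pixel values
    rng: a random.Random instance
    The coin is tossed once per sequence, never once per frame.
    """
    if rng.random() < 0.5:
        # reverse every row of every frame (left-right mirror)
        return [[row[::-1] for row in frame] for frame in frames]
    return frames


# usage: a 3-frame sequence of 2x3 frames
sequence = [[[t, t + 1, t + 2], [t + 3, t + 4, t + 5]] for t in range(3)]
augmented = flip_sequence(sequence, random.Random(42))
```

The output is either the untouched sequence or a sequence in which every frame is mirrored. It is never a mix of flipped and unflipped frames.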

Reading and displaying the image sequences

Now that we have a class that reads our TFRecords file and produces batches of image sequences with their corresponding labels, we will read these batches and display them in a tensorflow session. We will run 5 steps, each producing a batch of 4 sequences of 5 frames, and you can observe the frames below. To recap, we have a batch size of 4 and a sequence length of 5, so each displayed row is a sample of 5 successive extracted frames from one video. After displaying the first batch of 4, we display the next batch of 4, which contains the next 5 frames of each of the same videos. Hopefully this isn't too confusing, but the images below illustrate the process. This continues until all the windows of the current group of videos have been read (remember that we truncated them to the same length). The next group of videos in the TFRecords file follows.
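The frame numbers printed below can be predicted in advance. This assumes that every row of a batch comes from the same group of videos and that the group has enough frames for the steps shown. Under those assumptions, step s shows extracted frames s*5 to s*5+4 of each video. Since only every 5th video frame was extracted, their frame numbers are five times those positions. displayed_frames is a hypothetical helper:

```python
def displayed_frames(step, sequence_length=5, frame_step=5):
    """Video frame numbers shown in each row at a given reading step.

    Each record holds `sequence_length` consecutive extracted frames, and
    only every `frame_step`-th video frame was extracted.
    """
    first = step * sequence_length
    return [(first + t) * frame_step for t in range(sequence_length)]


for step in range(5):
    print('Displaying frames {}'.format(displayed_frames(step)))
```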

# make a dataset iterator (tensorflow 1.x graph mode)
data = DataSequenceReader(batch_size, sequence_length, num_epochs=10)
batch = data.read_batch('train.tfrecord', train=0)
iterator = batch.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):
        images, labels, name, frames = sess.run(next_element)

        fig = plt.figure(figsize=(12, 10))
        idx = 1
        for i in range(0, batch_size):
            # frames[i] is a byte string such as b'frame0frame5frame10frame15frame20'
            frame_ids = [int(num) for num in frames[i].decode().split('frame') if num]
            print('Displaying frames {}'.format(frame_ids))
            for j in range(0, sequence_length):
                # Display the frames along with the label by looking up the dictionary key
                ax = fig.add_subplot(batch_size, sequence_length, idx)
                ax.imshow(images[i, j, :, :, :])
                title = 'Class: {}'.format([key for key, value in frame_class.items() if value == labels[i, j]][0])
                ax.set_title(title)
                idx += 1
        plt.show()

Displaying frames [0, 5, 10, 15, 20]
Displaying frames [0, 5, 10, 15, 20]
Displaying frames [0, 5, 10, 15, 20]
Displaying frames [0, 5, 10, 15, 20]

[Figure: 4 x 5 grid of frames, one row per sequence in the batch, each frame titled with its class]

Displaying frames [25, 30, 35, 40, 45]
Displaying frames [25, 30, 35, 40, 45]
Displaying frames [25, 30, 35, 40, 45]
Displaying frames [25, 30, 35, 40, 45]

[Figure: 4 x 5 grid of frames, one row per sequence in the batch, each frame titled with its class]

Displaying frames [50, 55, 60, 65, 70]
Displaying frames [50, 55, 60, 65, 70]
Displaying frames [50, 55, 60, 65, 70]
Displaying frames [50, 55, 60, 65, 70]

[Figure: 4 x 5 grid of frames, one row per sequence in the batch, each frame titled with its class]

Displaying frames [75, 80, 85, 90, 95]
Displaying frames [75, 80, 85, 90, 95]
Displaying frames [75, 80, 85, 90, 95]
Displaying frames [75, 80, 85, 90, 95]

[Figure: 4 x 5 grid of frames, one row per sequence in the batch, each frame titled with its class]

Displaying frames [100, 105, 110, 115, 120]
Displaying frames [100, 105, 110, 115, 120]
Displaying frames [100, 105, 110, 115, 120]
Displaying frames [100, 105, 110, 115, 120]

[Figure: 4 x 5 grid of frames, one row per sequence in the batch, each frame titled with its class]

And that's it! This formatting of the data gives us batches of sequential images that we could use to train a recurrent neural network to learn some feature of the dataset. You can experiment with a different sequence length, but remember that you need to generate the TFRecords file again in order to get the right batches. To recap, the key takeaways are:

  1. How to generate a TFRecords file using the tf.train.SequenceExample format where each entry is a sequence of images and labels
  2. How to use the tensorflow Dataset API to read a batch of image sequences

Further reading

https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564
http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/